#!/usr/bin/env python2

import os
import sys
import socket
import time
import subprocess
import fcntl
import types
import errno
import getopt
import random
import re

from config import Config
from utils import Exp, _dwarn, _derror

# Entry.type values for each level of the placement tree.
# __TYPE_LIST__ marks the flat bookkeeping lists kept by DiskMap. It must be
# distinct from every real level: Entry.add() skips its duplicate-name check
# for __TYPE_LIST__ entries, so sharing a value with a level would silently
# disable that check for the level.
__TYPE_DISK__ = 1
__TYPE_NODE__ = 2
__TYPE_RACK__ = 3
__TYPE_SITE__ = 4
__TYPE_LIST__ = 5

# Module-wide configuration object. name2map() and map2name() read it for
# hostname <-> IP translation (ip2hostname_nohosts, hostname2ip_nohosts,
# and the nohosts flag).
config = Config()

def name2map(name, hosts=None):
    """Parse a disk name into ``[site, rack, node, disk]``.

    ``name`` is either a "site.rack.node/disk" style name (handled by
    _name2map) or "a.b.c.d/disk" where the host part is an IPv4 address.
    In the IPv4 case the address is first translated to a host name with
    ``config.ip2hostname_nohosts(ip, hosts)``. A missing "/disk" suffix is
    treated as "/0".
    """
    if '/' not in name:
        name = name + '/0'

    # The IPv4 form yields the address and disk directly; anything else is
    # parsed as a dotted site/rack/node name.
    m = re.match(r"(\d+\.\d+\.\d+\.\d+)/(\d+)", name)
    if m is None:
        return _name2map(name)

    ip, disk = m.groups()
    hostname = config.ip2hostname_nohosts(ip, hosts)
    return _name2map('/'.join([hostname, disk]))

def name2site(name, hosts=None):
    """Return the site component of ``name`` as parsed by name2map().

    ``hosts`` is forwarded to name2map() for IP-to-hostname translation.
    The default of None keeps the previous behaviour.
    """
    return name2map(name, hosts)[0]

def _name2map(name):
    """Split a "site.rack.node/disk" style name into its four components.

    Missing leading levels are filled with 'default', and a missing "/disk"
    suffix is treated as "/0":

        'a.b.c/1' -> ['a', 'b', 'c', '1']
        'b.c/1'   -> ['default', 'b', 'c', '1']
        'c'       -> ['default', 'default', 'c', '0']

    Raises:
        Exp(EINVAL): the name has more than two '.' separators, or matches
        none of the accepted forms.
    """
    # Validate explicitly rather than with assert: asserts are stripped under
    # -O, which would let e.g. 'a.b.c.d/0' through as a single node name.
    if name.count('.') > 2:
        print("name : %s" % (name))
        raise Exp(errno.EINVAL, str(errno.EINVAL))

    if '/' not in name:
        name = name + '/0'

    # Try the most specific form first. Each pattern captures the trailing
    # components; the missing leading ones are padded with 'default'.
    for pattern in (r"([^.]+)\.([^.]+)\.([^.]+)/(\d+)",
                    r"([^.]+)\.([^.]+)/(\d+)",
                    r"([^/]+)/(\d+)"):
        m = re.match(pattern, name)
        if m is not None:
            parts = list(m.groups())
            return ['default'] * (4 - len(parts)) + parts

    print("name : %s" % (name))
    raise Exp(errno.EINVAL, str(errno.EINVAL))

def map2name(lst, hosts=None):
    """Inverse of name2map(): turn ``[site, rack, node, disk]`` into a name.

    Up to two leading "default." components are dropped from the joined
    "site.rack.node" prefix. When ``config.nohosts`` is set, the remaining
    host part is translated with ``config.hostname2ip_nohosts(node, hosts)``.
    The result is "<host>/<disk>".
    """
    prefix = 'default.'
    node = ".".join(lst[:-1])

    # Strip at most two "default." components (the site and the rack).
    stripped = 0
    while stripped < 2 and node.startswith(prefix):
        node = node[len(prefix):]
        stripped += 1

    ip = config.hostname2ip_nohosts(node, hosts) if config.nohosts else node
    return "/".join([ip, lst[-1]])

class Entry:
    """One element of the placement tree (site, rack, node or disk), or one
    of DiskMap's flat index lists (type __TYPE_LIST__).

    Attributes:
        type:   one of the __TYPE_*__ constants.
        child:  list of child Entry objects; get_inchild() also uses it as a
                round-robin queue.
        parent: parent Entry in the tree, or None.
        table:  the flat list this entry was registered in, or None.
        name:   short name of this level.
        id:     identifier used for lookups.
    """

    def __init__(self, _id, name, entry_type, parent, table):
        self.type = entry_type
        self.child = []
        self.parent = parent
        self.table = table
        self.name = name
        self.id = _id

    def sort(self):
        """Order children by their own child count, largest first.

        The sort is stable: children with equal counts keep their relative
        order.
        """
        # key= works on both Python 2 and 3 (cmp= was removed in Python 3)
        # and produces the same stable ordering as the old cmp comparator.
        self.child = sorted(self.child, key=lambda ent: len(ent.child),
                            reverse=True)

    def add(self, ent):
        """Append ``ent`` as a child.

        Raises:
            Exp(EEXIST): a child with the same name already exists. The check
            is done for every type except __TYPE_LIST__.
        """
        if self.type != __TYPE_LIST__:
            for i in self.child:
                if i.name == ent.name:
                    raise Exp(errno.EEXIST, "%s exist" % (i.name))

        self.child.append(ent)

    def get_inchild(self):
        """Return the first child and rotate it to the back (round robin).

        Raises IndexError when there are no children.
        """
        ent = self.child.pop(0)
        self.child.append(ent)
        return ent

    def __str__(self):
        return self.id + ":" + str([i.id for i in self.child])

class DiskMap:
    """Hierarchical map of disks: cluster -> site -> rack -> node -> disk.

    Besides the tree rooted at ``self.cluster``, DiskMap keeps flat index
    lists of every site, rack, node and disk in ``self.site``, ``self.rack``,
    ``self.node`` and ``self.disk``. ``self.map`` maps a disk name to its
    disk Entry. Allocation methods rotate children (pop front, append back),
    so repeated calls spread picks across the hierarchy.
    """

    def __init__(self):
        self.cluster = Entry("", "", __TYPE_LIST__, None, None)
        self.site = Entry("", "", __TYPE_LIST__, None, None)
        self.rack = Entry("", "", __TYPE_LIST__, None, None)
        self.node = Entry("", "", __TYPE_LIST__, None, None)
        self.disk = Entry("", "", __TYPE_LIST__, None, None)
        self.map = {}

    def __lookup(self, lst, _id):
        """Return the first child of ``lst`` whose id is ``_id``, or None."""
        for ent in lst.child:
            if ent.id == _id:
                return ent
        return None

    def __find(self, lst, _id):
        """Return the child of ``lst`` whose id is ``_id``.

        Raises:
            Exp(ENOENT): no such child.
        """
        ent = self.__lookup(lst, _id)
        if ent is None:
            raise Exp(errno.ENOENT, "%s not exist" % (_id))
        return ent

    def __exist(self, lst, _id):
        """Return ``(True, child_count)`` if ``_id`` is a child of ``lst``,
        otherwise ``(False, 0)``."""
        ent = self.__lookup(lst, _id)
        if ent is None:
            return (False, 0)
        return (True, len(ent.child))

    def __remove(self, lst, _id):
        """Delete the first child of ``lst`` whose id is ``_id``, if any."""
        for i, ent in enumerate(lst.child):
            if ent.id == _id:
                del lst.child[i]
                return

    def __get_inlist(self, lst):
        """Take the first child of ``lst`` and rotate it to the back."""
        ent = lst.child.pop(0)
        lst.child.append(ent)
        return ent

    def __get_inchild(self, lst, skip):
        """Pick the next child of ``lst`` (round robin) whose id is not in ``skip``.

        If every child is in ``skip``:
        - when ``lst`` is a node (its children are disks), return None;
        - otherwise return the skipped child with the fewest entries under
          it in ``skip``.
        Returns None when ``lst`` has no children.
        """
        array = []
        retry = 0
        while retry < len(lst.child):
            ent = self.__get_inlist(lst)
            exist, count = self.__exist(skip, ent.id)
            if not exist:
                return ent

            retry += 1
            if len(lst.child) > len(skip.child) or lst.type == __TYPE_NODE__:
                continue
            array.append((ent, count))

        if array:
            # min() returns the first of several equal counts, matching the
            # stable ascending sort it replaces.
            return min(array, key=lambda pair: pair[1])[0]
        return None

    def __get_bydisk(self, count, skip):
        """Pick up to ``count`` disks round robin from the flat disk list,
        ignoring disks in ``skip.disk``.

        ``count`` is clipped to the number of disks not in ``skip``.

        Raises:
            Exp(ENOSPC): no disk is available, there is more than one site,
            or the first site has more than one rack.
        """
        available = len(self.disk.child) - len(skip.disk.child)
        if available < count:
            count = available
            if count <= 0:
                raise Exp(errno.ENOSPC, "no space left")

        if len(self.site.child) > 1:
            raise Exp(errno.ENOSPC, "got multi rack, no space left")

        if len(self.site.child) > 0 and len(self.site.child[0].child) > 1:
            raise Exp(errno.ENOSPC, "got multi node, no space left")

        _dwarn("allocate by disk %d" % (count))
        res = []
        for i in range(count):
            disk = self.__get_inlist(self.disk)
            while self.__exist(skip.disk, disk.id)[0]:
                disk = self.__get_inlist(self.disk)
            res.append(disk.id)

        return res

    def __get_bynode(self, count, skip):
        """Pick ``count`` disks, one per node not in ``skip.node``.

        If there are not enough free nodes, take one disk from each free
        node, then fill the remainder with __get_bydisk(). A failure of
        __get_bydisk() is ignored, and only the per-node picks are returned.
        """
        if len(self.node.child) < count + len(skip.node.child):
            newcount = len(self.node.child) - len(skip.node.child)
            if newcount > 0:
                res = self.__get_bynode(newcount, skip)
                for i in res:
                    skip.add(i)
            else:
                # Never ask __get_bydisk for more than ``count`` disks.
                newcount = 0
                res = []

            try:
                return res + self.__get_bydisk(count - newcount, skip)
            except Exp:
                return res

        res = []
        for i in range(count):
            node = self.__get_inlist(self.node)
            while self.__exist(skip.node, node.id)[0]:
                node = self.__get_inlist(self.node)
            res.append(node.get_inchild().id)

        return res

    def get_in(self, _id, count, _skip=None):
        """Allocate ``count`` disks inside the site or rack whose id is ``_id``.

        If ``_id`` names a disk, return ``[_id]``. The process exits with
        EINVAL if ``_id`` is a node, and with ENOENT if it is unknown.

        Errors raised by the site/rack allocators (e.g. Exp(ENOSPC)) reach
        the caller. They are no longer swallowed by the lookup fallbacks and
        misreported as "not found".
        """
        site = self.__lookup(self.site, _id)
        if site is not None:
            return self.__get_insite(site, count, _skip)

        rack = self.__lookup(self.rack, _id)
        if rack is not None:
            return self.__get_inrack(rack, count, _skip)

        if self.__lookup(self.node, _id) is not None:
            _derror("move to %s is not support" % (_id))
            sys.exit(errno.EINVAL)

        if self.__lookup(self.disk, _id) is not None:
            return [_id]

        _derror("move to %s not found" % (_id))
        sys.exit(errno.ENOENT)

    def get_inrack(self, name, count, _skip=None):
        """Allocate ``count`` disks inside rack ``name``.

        Falls back to the first known rack when ``name`` is unknown.
        """
        rack = self.__lookup(self.rack, name)
        if rack is None:
            _derror("rack %s not available" % (name))
            rack = self.rack.child[0]

        return self.__get_inrack(rack, count, _skip)

    def __get_inrack(self, rack, count, _skip=None):
        """Allocate up to ``count`` disks on distinct disks of ``rack``.

        Disks named in ``_skip`` (and those already picked) are avoided.

        Raises:
            Exp(ENOSPC): no node is available, or no disk could be found
            before anything was allocated.
        """
        skip = self.skip(_skip)

        res = []
        for i in range(count):
            retry = 0
            while True:
                node = self.__get_inchild(rack, skip.node)
                if node is None:
                    raise Exp(errno.ENOSPC, "no node available")

                disk = self.__get_inchild(node, skip.disk)
                if disk is not None:
                    skip.add(disk.id)
                    res.append(disk.id)
                    break

                _derror("no node available @ %s, skip %s retry %u" % (str(node), str(skip.disk), retry))
                if retry < len(rack.child):
                    retry += 1
                    continue

                if len(res) == 0:
                    raise Exp(errno.ENOSPC, "no disk available")
                break

        return res

    def get_insite(self, name, count, _skip=None):
        """Allocate ``count`` disks inside site ``name``.

        Falls back to the first known site when ``name`` is unknown.
        """
        site = self.__lookup(self.site, name)
        if site is None:
            _derror("site %s not available" % (name))
            site = self.site.child[0]

        return self.__get_insite(site, count, _skip)

    def __get_insite(self, site, count, _skip=None):
        """Allocate up to ``count`` disks in ``site``, rotating across racks.

        Raises:
            Exp(ENOSPC): the site has no rack, or no disk could be found
            before anything was allocated.
        """
        skip = self.skip(_skip)

        res = []
        for i in range(count):
            retry = 0
            while True:
                rack = self.__get_inchild(site, skip.rack)
                if rack is None:
                    raise Exp(errno.ENOSPC, "no rack available")

                node = self.__get_inchild(rack, skip.node)
                if node is None:
                    # ``node`` is None here; report the rack searched instead.
                    _derror("no node available @ %s, skip %s" % (str(rack), str(skip.node)))
                    continue

                disk = self.__get_inchild(node, skip.disk)
                if disk is not None:
                    skip.add(disk.id)
                    res.append(disk.id)
                    break

                _derror("no node available @ %s, skip %s retry %u" % (str(node), str(skip.disk), retry))
                if retry < len(rack.child):
                    retry += 1
                    continue

                if len(res) == 0:
                    raise Exp(errno.ENOSPC, "no disk available")
                break

        return res

    def __get_count(self, count, skip):
        """Clip ``count`` to the number of disks not listed in ``skip``."""
        return min(count, len(self.disk.child) - len(skip.disk.child))

    def __get_byrack(self, count, skip):
        """Pick ``count`` disks, each from the next rack (round robin over the
        flat rack list) that is not in ``skip.rack``.

        Raises:
            Exp(ENOSPC): every rack is in ``skip`` (or there are no racks).
        """
        res = []
        for i in range(count):
            # Rotate the flat *rack* list (previously the site list, which
            # yielded node ids instead of disk ids). Visit each rack at most
            # once so a fully skipped map cannot loop forever.
            for _ in range(len(self.rack.child)):
                rack = self.rack.get_inchild()
                if not self.__exist(skip.rack, rack.id)[0]:
                    break
            else:
                raise Exp(errno.ENOSPC, "no rack available")

            node = rack.get_inchild()
            res.append(node.get_inchild().id)

        return res

    def get(self, count, _skip=None, downgrade=False):
        """Allocate up to ``count`` disks, rotating over sites, racks and nodes.

        Disks named in ``_skip`` (and those already picked) are avoided.
        ``downgrade`` is accepted but not used.

        Raises:
            Exp(ENOSPC): no site is available, or no disk could be found
            before anything was allocated.
        """
        skip = self.skip(_skip)

        res = []
        for i in range(count):
            retry = 0
            while True:
                site = self.__get_inchild(self.cluster, skip.site)
                if site is None:
                    raise Exp(errno.ENOSPC, "no site available")

                rack = self.__get_inchild(site, skip.rack)
                if rack is None:
                    # Report the site searched; ``node`` may be unbound here
                    # (NameError) or stale from a previous pass.
                    _derror("no rack available @ %s, skip %s" % (str(site), str(skip.rack)))
                    continue

                node = self.__get_inchild(rack, skip.node)
                if node is None:
                    _derror("no node available @ %s, skip %s" % (str(rack), str(skip.node)))
                    continue

                disk = self.__get_inchild(node, skip.disk)
                if disk is not None:
                    skip.add(disk.id)
                    res.append(disk.id)
                    break

                _derror("no disk available @ %s, skip %s, retry %u" % (str(node), str(skip.disk), retry))
                if retry < len(self.node.child):
                    retry += 1
                    continue

                if len(res) == 0:
                    raise Exp(errno.ENOSPC, "no disk available")
                break

        return res

    def __add(self, parent, lst, _id, name, _type):
        """Return the child ``_id`` of ``parent``.

        If it does not exist, create it and register it in both ``parent``
        and the flat list ``lst``.
        """
        ent = self.__lookup(parent, _id)
        if ent is None:
            ent = Entry(_id, name, _type, parent, lst)
            parent.add(ent)
            lst.add(ent)
        return ent

    def add(self, name):
        """Insert disk ``name``, creating its site/rack/node entries as needed."""
        lst = name2map(name)

        site = self.__add(self.cluster, self.site, lst[0], lst[0], __TYPE_SITE__)
        rack = self.__add(site, self.rack, lst[0] + '.' + lst[1], lst[1], __TYPE_RACK__)
        node = self.__add(rack, self.node, lst[0] + '.' + lst[1] + '.' + lst[2], lst[2], __TYPE_NODE__)
        disk = self.__add(node, self.disk, name, lst[3], __TYPE_DISK__)
        self.map[name] = disk

    def delete(self, name):
        """Remove disk ``name`` and prune ancestors left without children.

        Emptied nodes, racks and sites are removed from the tree and also
        from the flat ``self.node`` / ``self.rack`` / ``self.site`` lists.
        Counts, lookups and get_min_*/get_max_* therefore no longer see them.

        Raises:
            Exp(ENOENT): ``name`` is not a known disk.
        """
        disk = self.__find(self.disk, name)
        host = disk.parent
        rack = host.parent
        site = rack.parent

        self.__remove(self.disk, name)
        self.__remove(host, name)

        if len(host.child) == 0:
            self.__remove(rack, host.id)
            self.__remove(self.node, host.id)

        if len(rack.child) == 0:
            self.__remove(site, rack.id)
            self.__remove(self.rack, rack.id)

        if len(site.child) == 0:
            self.__remove(site.parent, site.id)
            self.__remove(self.site, site.id)

        del self.map[name]

    def addmulti(self, lst):
        """Insert every disk name in ``lst``, in order."""
        for i in lst:
            self.add(i)

    def skip(self, _skip):
        """Build a DiskMap of the names in ``_skip`` that exist in this map."""
        skip = DiskMap()
        for i in (_skip or []):
            # Test the dict directly (O(1)) instead of materialising keys().
            if i in self.map:
                skip.add(i)
        return skip

    def dump(self, name=""):
        """Print the disk list and the site/rack/host/disk tree."""
        print("dump @ %s --> %s" % (name, str(self.list())))
        for s in self.cluster.child:
            print("site :%s" % (s.name))
            for r in s.child:
                print("    rack :%s" % (r.name))
                for h in r.child:
                    print("         host :%s" % (h.name))
                    for d in h.child:
                        print("            disk :%s(%s)" % (d.name, d.id))

    def site_count(self):
        return len(self.site.child)

    def rack_count(self):
        return len(self.rack.child)

    def node_count(self):
        return len(self.node.child)

    def disk_count(self):
        return len(self.disk.child)

    def list(self):
        """Return the names of all disks in the map."""
        return list(self.map.keys())

    def get_min_site(self):
        """Return the site with the fewest racks."""
        self.site.sort()
        return self.site.child[-1]

    def get_max_site(self):
        """Return the site with the most racks."""
        self.site.sort()
        return self.site.child[0]

    def get_min_rack(self, skip=None):
        """Return the rack with the fewest nodes.

        Rack ids contained in ``skip`` are passed over. Returns None if every
        rack is skipped.
        """
        self.rack.sort()
        if skip is None:
            return self.rack.child[-1]

        for rack in reversed(self.rack.child):
            if rack.id not in skip:
                return rack
        return None

    def get_max_rack(self):
        """Return the rack with the most nodes."""
        self.rack.sort()
        return self.rack.child[0]

    def get_min_node(self):
        """Return the node with the fewest disks."""
        self.node.sort()
        # Previously returned self.rack.child[-1] (copy/paste slip).
        return self.node.child[-1]

    def get_max_node(self):
        """Return the node with the most disks."""
        self.node.sort()
        return self.node.child[0]

def main():
    """Build a sample two-site map, dump it and print the allocator results."""
    disk_map = DiskMap()

    # (host, number of disks) in insertion order; disks are numbered from 0.
    layout = [
        ("d1.r1.h1", 11),
        ("d1.r1.h2", 4),
        ("d1.r1.h3", 4),
        ("d1.r1.h4", 4),
        ("d1.r3.h1", 4),
        ("d2.r2.h1", 4),
        ("d2.r2.h2", 4),
        ("d2.r3.h3", 4),
        ("d2.r3.h4", 4),
    ]
    disk_map.addmulti(["%s/%d" % (host, idx)
                       for host, ndisk in layout
                       for idx in range(ndisk)])

    print(disk_map.list())

    disk_map.dump()

    # Cross-rack allocation: (label, count, skip list).
    cross_rack = [
        ("test by rack 8:", 8, None),
        ("test by rack 9:", 9, None),
        ("test by rack skip 4:", 4, ["d1.r1.h1/0", "d1.r1.h1/2"]),
        ("test by rack skip 5:", 5, ["d1.r1.h1/1", "d1.r1.h1/3"]),
        ("test by rack skip 6:", 6, ["d1.r1.h1/1", "d1.r1.h1/4"]),
        ("test by rack skip 7:", 7, ["d1.r1.h1/1", "d1.r1.h1/5"]),
    ]
    for label, count, skip in cross_rack:
        print(label + str(disk_map.get(count, skip)))

    # In-rack allocation: (label, rack, count, skip list).
    in_rack = [
        ("test in rack 4:", "d2.r2", 4, None),
        ("test in rack 5:", "d2.r2", 5, None),
        ("test in rack 6:", "d2.r2", 6, None),
        ("test in rack 7:", "d2.r2", 7, None),
        ("test in rack 8:", "d2.r2", 8, None),
        ("test in rack 9:", "d2.r2", 9, None),
        ("test in rack skip2:", "d1.r1", 2, ["d1.r1.h1/0", "d1.r1.h1/2"]),
        ("test in rack skip3:", "d2.r2", 3, ["d1.r1.h1/0", "d1.r1.h1/2"]),
        ("test in rack skip4:", "d1.r3", 4, ["d1.r1.h1/0", "d1.r1.h1/2"]),
    ]
    for label, rack, count, skip in in_rack:
        print(label + str(disk_map.get_inrack(rack, count, skip)))

def _name2map_test():
    """Print how _name2map() handles a mix of valid and malformed names."""
    samples = ('192.168.1.106/1', 'a', 'a.', 'a.b',
               'a.b.c', 'a.b.c.', 'a/0', 'a.b/0',
               'a.b./0', 'a.b.c/0', '/0', '0')

    for sample in samples:
        try:
            parsed = _name2map(sample)
            print("ok %s %s" % (sample, parsed))
        except Exception as e:
            print("except %s %s" % (str(e), sample))

if __name__ == '__main__':
    # Manual smoke test of the name parsing helpers.
    print(name2map('192.168.1.106/1'))
    print(name2map('192.168.1.104/0'))
    # _name2map_test() prints its own results and returns None, so call it
    # directly rather than printing that None.
    _name2map_test()
    #print map2name(['d1', 'r1', '192.168.1.106', '0'])
    #main()
